{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train a character-level GPT on some text data\n",
    "\n",
    "The inputs here are simple text files, which we chop up into individual characters and then train a GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some Shakespeare and train it to predict the text one character at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up logging\n",
    "import logging\n",
    "logging.basicConfig(\n",
    "        format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n",
    "        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "        level=logging.INFO,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make deterministic\n",
    "from mingpt.utils import set_seed\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting dataset.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile dataset.py\n",
    "\n",
    "import math\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "class CharDataset(Dataset):\n",
    "\n",
    "    def __init__(self, data, block_size):\n",
    "        chars = sorted(list(set(data)))\n",
    "        data_size, vocab_size = len(data), len(chars)\n",
    "        print('data has %d characters, %d unique.' % (data_size, vocab_size))\n",
    "        \n",
    "        self.stoi = { ch:i for i,ch in enumerate(chars) }\n",
    "        self.itos = { i:ch for i,ch in enumerate(chars) }\n",
    "        self.block_size = block_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.data = data\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data) - self.block_size\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # grab a chunk of (block_size + 1) characters from the data\n",
    "        chunk = self.data[idx:idx + self.block_size + 1]\n",
    "        # encode every character to an integer\n",
    "        dix = [self.stoi[s] for s in chunk]\n",
    "        \"\"\"\n",
    "        arrange data and targets so that the first i elements of x\n",
    "        will be asked to predict the i-th element of y. Notice that\n",
    "        the eventual language model will actually make block_size\n",
    "        individual predictions at the same time based on this data,\n",
    "        so we are being clever and amortizing the cost of the forward\n",
    "        pass of the network. So for example if block_size is 4, then\n",
    "        we could e.g. sample a chunk of text \"hello\", the integers in\n",
    "        x will correspond to \"hell\" and in y will be \"ello\". This will\n",
    "        then actually \"multitask\" 4 separate examples at the same time\n",
    "        in the language model:\n",
    "        - given just \"h\", please predict \"e\" as next\n",
    "        - given \"he\" please predict \"l\" next\n",
    "        - given \"hel\" predict \"l\" next\n",
    "        - given \"hell\" predict \"o\" next\n",
    "        \n",
    "        In addition, because the DataLoader will create batches of examples,\n",
    "        every forward/backward pass during training will simultaneously train\n",
    "        a LOT of predictions, amortizing a lot of computation. In particular,\n",
    "        for a batched input of integers X (B, T) where B is batch size and\n",
    "        T is block_size and Y (B, T), the network will during training be\n",
    "        simultaneously training to make B*T predictions, all at once! Of course,\n",
    "        at test time we can parallelize across batch B, but unlike during training\n",
    "        we cannot parallelize across the time dimension T - we have to run\n",
    "        a forward pass of the network to recover the next single character of the\n",
    "        sequence along each batch dimension, then append that character to the\n",
    "        input and repeat to get the one after it.\n",
    "        \n",
    "        So yes there is a big asymmetry between train/test time of autoregressive\n",
    "        models. During training we can go B*T at a time with every forward pass,\n",
    "        but during test time we can only go B at a time, T times, with T forward \n",
    "        passes.\n",
    "        \"\"\"\n",
    "        x = torch.tensor(dix[:-1], dtype=torch.long)\n",
    "        y = torch.tensor(dix[1:], dtype=torch.long)\n",
    "        return x, y\n"
   ]
  },
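  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shift between `x` and `y` described in the `CharDataset` comment can be checked on a toy string. This is a minimal sketch, not part of the original notebook: the string `'hello'`, the block size of 4, and every `toy_` name are made up for illustration, and the vocabulary is built the same way `CharDataset` builds it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# toy illustration (hypothetical data): one chunk of block_size + 1 characters\n",
    "import torch\n",
    "\n",
    "toy_text = 'hello'\n",
    "toy_block_size = 4\n",
    "toy_chars = sorted(set(toy_text))\n",
    "toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}\n",
    "\n",
    "toy_chunk = toy_text[0:toy_block_size + 1]\n",
    "toy_dix = [toy_stoi[s] for s in toy_chunk]\n",
    "toy_x = torch.tensor(toy_dix[:-1], dtype=torch.long)  # encodes 'hell'\n",
    "toy_y = torch.tensor(toy_dix[1:], dtype=torch.long)   # encodes 'ello'\n",
    "\n",
    "# each prefix of the input is paired with the single character that follows it\n",
    "for t in range(toy_block_size):\n",
    "    print('given %r predict %r' % (toy_chunk[:t + 1], toy_chunk[t + 1]))"
   ]
  },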
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "block_size = 128 # context length: the maximum number of characters the model sees at once"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100 1089k  100 1089k    0     0   543k      0  0:00:02  0:00:02 --:--:--  543k\n"
     ]
    }
   ],
   "source": [
    "!curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o input.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data has 1115394 characters, 65 unique.\n"
     ]
    }
   ],
   "source": [
    "# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt\n",
    "from dataset import CharDataset\n",
    "with open('input.txt', 'r') as f: # the file handle is closed once the text is read\n",
    "    text = f.read()\n",
    "train_dataset = CharDataset(text, block_size) # one line of poem is roughly 50 characters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "01/05/2025 12:21:20 - INFO - mingpt.model -   number of parameters: 6.384640e+06\n"
     ]
    }
   ],
   "source": [
    "from mingpt.model import GPT, GPTConfig\n",
    "mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,\n",
    "                  n_layer=8, n_head=4, n_embd=256)\n",
    "model = GPT(mconf)"
   ]
  },
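  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the logged parameter count, the total can be tallied by hand. This is a rough sketch under an assumed layer layout (learned token and position embeddings, biases on every linear layer inside the blocks, a 4x-wide MLP, two LayerNorms per block plus a final one, and a bias-free output head). The helper `count_gpt_params` is hypothetical and not part of mingpt. With the settings above it gives 6,384,640, which agrees with the `6.384640e+06` in the log."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_gpt_params(vocab_size, block_size, n_layer, n_embd):\n",
    "    # hypothetical helper: parameter tally under the assumed layout described above\n",
    "    def linear(n_in, n_out):\n",
    "        return n_in * n_out + n_out                   # weight matrix plus bias vector\n",
    "\n",
    "    embeddings = vocab_size * n_embd + block_size * n_embd\n",
    "    attention = 4 * linear(n_embd, n_embd)            # key, query, value, output projection\n",
    "    mlp = linear(n_embd, 4 * n_embd) + linear(4 * n_embd, n_embd)\n",
    "    layer_norms = 2 * (2 * n_embd)                    # weight and bias for each of two LayerNorms\n",
    "    block = attention + mlp + layer_norms\n",
    "    final_layer_norm = 2 * n_embd\n",
    "    head = n_embd * vocab_size                        # assumed to have no bias\n",
    "    return embeddings + n_layer * block + final_layer_norm + head\n",
    "\n",
    "print(count_gpt_params(vocab_size=65, block_size=128, n_layer=8, n_embd=256))"
   ]
  },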
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1 iter 2178: train loss 0.96921. lr 3.000169e-04: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 2179/2179 [3:58:31<00:00,  6.57s/it]\n",
      "epoch 2 iter 2178: train loss 0.74673. lr 6.000000e-05: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 2179/2179 [4:00:49<00:00,  6.63s/it]\n"
     ]
    }
   ],
   "source": [
    "from mingpt.trainer import Trainer, TrainerConfig\n",
    "\n",
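    "# initialize a trainer instance and kick off training\n",
    "# final_tokens is the total number of tokens seen over training (epochs * examples * block_size);\n",
    "# the learning rate decay runs over that many tokens\n",
    "tconf = TrainerConfig(max_epochs=2, batch_size=512, learning_rate=6e-4,\n",
    "                      lr_decay=True, warmup_tokens=512*20, final_tokens=2*len(train_dataset)*block_size,\n",
    "                      num_workers=4)\n",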
    "trainer = Trainer(model, train_dataset, None, tconf)\n",
    "trainer.train()"
   ]
  },
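  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The learning rates in the progress bars above are consistent with a linear warmup followed by cosine decay to a floor of 10% of the initial rate. The sketch below assumes that schedule. `lr_at` is a hypothetical helper, not the trainer's API, and the constants are copied from this notebook's config and dataset size. Under these assumptions it gives about 3.0e-04 halfway through and 6.0e-05 at the end, in line with the two log lines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# total tokens over 2 epochs: epochs * (characters - block_size) * block_size\n",
    "total_tokens = 2 * (1115394 - 128) * 128\n",
    "\n",
    "def lr_at(tokens, learning_rate=6e-4, warmup_tokens=512 * 20, final_tokens=total_tokens):\n",
    "    # hypothetical helper: assumed linear warmup, then cosine decay with a 10% floor\n",
    "    if tokens < warmup_tokens:\n",
    "        lr_mult = tokens / max(1, warmup_tokens)\n",
    "    else:\n",
    "        progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)\n",
    "        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))\n",
    "    return learning_rate * lr_mult\n",
    "\n",
    "# start, mid-warmup, end of epoch 1, end of epoch 2\n",
    "for tokens in (0, 512 * 10, total_tokens // 2, total_tokens):\n",
    "    print('%12d tokens -> lr %e' % (tokens, lr_at(tokens)))"
   ]
  },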
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "O God, O God! fall dead, I have of my poor head!\n",
      "\n",
      "DUCHESS OF YORK:\n",
      "Thou takest of needful birth midst be, newly kiss'd.\n",
      "In that doubt he of this slanderous world,\n",
      "Shamed to his soon and the throne belly,\n",
      "That she is the fairies' milks of nature.\n",
      "\n",
      "BRUTUS:\n",
      "She's a disease;\n",
      "Honourable Menenius, his beauty's death.\n",
      "\n",
      "SICINIUS:\n",
      "This a caution couple, I never show\n",
      "The people's voices, which they are devised.\n",
      "There is no great prodigatical\n",
      "To your great enemy; either for your own part, you rather\n",
      "With the devil close of your own report\n",
      "But that you dissention, with all that very eyes\n",
      "For why I repent; express you'll absence,\n",
      "But follow me, to the purpose of York,\n",
      "And spoke of his own rhand: I hear him speak it\n",
      "As it would he had that gallant of the neck\n",
      "And peril our loving coats with the haste.\n",
      "\n",
      "First Servingman:\n",
      "But how come into my command?\n",
      "\n",
      "Third Servingman:\n",
      "What, are you married? I will rend to Romeo?\n",
      "\n",
      "BRUTUS:\n",
      "Ay, sir, in me;\n",
      "A part of its strong wind and scarries shuns,\n",
      "Thou wilt blame you not purchase your guest\n",
      "Upon my bad, and misbraced my son Angelo.\n",
      "\n",
      "DUKE VINCENTIO:\n",
      "Not yet?\n",
      "\n",
      "Provost:\n",
      "Provost:\n",
      "Is it your will means to bed; but what comes here?\n",
      "\n",
      "DUKE VINCENTIO:\n",
      "My brother was your daughter\n",
      "And leave your thief; and, if he be, sir,\n",
      "You have heard your turn, and that have poison'd\n",
      "And in my sight fortune shall make you kill my heart.\n",
      "\n",
      "LUCIO:\n",
      "And that's your disposition of your sister.\n",
      "\n",
      "ISABELLA:\n",
      "Some hour that he will say somethings had been sick,\n",
      "And bring him not, he commixture himself\n",
      "After the curses she with oaths and effect;\n",
      "As bitter are she breaks of, and gave the louring clouds,\n",
      "Say to them here and stand all to come:\n",
      "When it first, it is a sound thirty thousand strong.\n",
      "\n",
      "Servant:\n",
      "Please your part, there you may pack it back.\n",
      "\n",
      "DUKE VINCENTIO:\n",
      "The time is gone of Claudio, and the meaner proud.\n",
      "\n",
      "ESCALUS:\n",
      "My lord of Lord Northumberland, his son?\n",
      "\n",
      "DUKE VINCENTIO:\n",
      "His name is Lucentio hither, no.\n",
      "\n",
      "LUCIO:\n",
      "As with him, my lord.\n",
      "\n",
      "LUCIO:\n",
      "I beseech you, my lord.\n",
      "\n",
      "DUKE VINCENTIO:\n",
      "W\n"
     ]
    }
   ],
   "source": [
    "# alright, let's sample some character-level Shakespeare\n",
    "from mingpt.utils import sample\n",
    "\n",
    "context = \"O God, O God!\"\n",
    "x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)\n",
    "y = sample(model, x, 2000, temperature=1.0, sample=True, top_k=10)[0]\n",
    "completion = ''.join([train_dataset.itos[int(i)] for i in y])\n",
    "print(completion)"
   ]
  },
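  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `temperature` and `top_k` arguments passed to `sample` control how each next character is drawn. A single sampling step could look roughly like the sketch below. `sample_next` is a hypothetical helper and the logits are made up, so this illustrates temperature scaling and top-k filtering in general rather than reproducing `mingpt.utils.sample`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "def sample_next(logits, temperature=1.0, top_k=None):\n",
    "    # hypothetical helper; logits has shape (B, vocab_size)\n",
    "    logits = logits / temperature  # below 1 sharpens the distribution, above 1 flattens it\n",
    "    if top_k is not None:\n",
    "        kth_best = torch.topk(logits, top_k).values[:, [-1]]\n",
    "        # scores below the k-th best get zero probability after the softmax\n",
    "        logits = logits.masked_fill(logits < kth_best, float('-inf'))\n",
    "    probs = F.softmax(logits, dim=-1)\n",
    "    return torch.multinomial(probs, num_samples=1)  # shape (B, 1)\n",
    "\n",
    "# made-up scores for a 5-symbol vocabulary; with top_k=2 only index 0 or 1 can be drawn\n",
    "toy_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])\n",
    "print(sample_next(toy_logits, temperature=1.0, top_k=2))"
   ]
  },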
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# well that was fun"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
